// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_
#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_

#include <cstddef>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>

#include "absl/random/bit_gen_ref.h"
#include "absl/random/distributions.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "./fuzztest/internal/domains/domain_base.h"
#include "./fuzztest/internal/domains/serialization_helpers.h"
#include "./fuzztest/internal/logging.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/status.h"
#include "./fuzztest/internal/type_support.h"

namespace fuzztest::internal {

// FlatMap takes a domain factory function (flat mapper) and an input domain
// for each parameter of the factory function. The output domain is what the
// flat mapper returns and the domain that FlatMap represents. I.e., the "output
// domain" is re-created dynamically, as it depends on values created by the
// input domains.
template <typename FlatMapper, typename... InputDomain>
using FlatMapOutputDomain = std::decay_t<
    std::invoke_result_t<FlatMapper, value_type_t<InputDomain>...>>;

// Domain implementation behind `FlatMap()`.
//
// Holds a flat mapper and one input domain per flat-mapper parameter. The
// output domain is never stored: it is rebuilt on demand by calling the flat
// mapper with the user values of the current input corpus values.
//
// Corpus layout: `std::tuple<output_corpus, input_corpus...>`. Element 0 is the
// output domain's corpus value; the remaining elements are the input domains'
// corpus values, in the same order as `InputDomain...`.
template <typename FlatMapper, typename... InputDomain>
class FlatMapImpl
    : public domain_implementor::DomainBase<
          FlatMapImpl<FlatMapper, InputDomain...>,
          // The user value is the user value of the output domain.
          value_type_t<FlatMapOutputDomain<FlatMapper, InputDomain...>>,
          // The corpus value is a tuple where the first element is the corpus
          // value of the output domain, and the rest is the corpus value of the
          // input domains.
          std::tuple<
              corpus_type_t<FlatMapOutputDomain<FlatMapper, InputDomain...>>,
              corpus_type_t<InputDomain>...>> {
 public:
  using typename FlatMapImpl::DomainBase::corpus_type;
  using typename FlatMapImpl::DomainBase::value_type;

  FlatMapImpl() = default;
  explicit FlatMapImpl(FlatMapper flat_mapper, InputDomain... input_domains)
      : flat_mapper_(std::move(flat_mapper)),
        input_domains_(std::move(input_domains)...) {}

  // Returns a seed from `MaybeGetRandomSeed` when one is produced. Otherwise
  // initializes every input domain, builds the output domain from those input
  // values, and initializes the output value from that output domain.
  corpus_type Init(absl::BitGenRef prng) {
    if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
    auto input_corpus = std::apply(
        [&](auto&... input_domains) {
          return std::make_tuple(input_domains.Init(prng)...);
        },
        input_domains_);
    auto output_domain = GetOutputDomain(input_corpus);
    // `output_domain.Init(prng)` is fully evaluated before `tuple_cat` consumes
    // its arguments, so `input_corpus` is no longer needed and can be moved
    // into the result instead of copied.
    return std::tuple_cat(std::make_tuple(output_domain.Init(prng)),
                          std::move(input_corpus));
  }

  // Either mutates the input values and re-initializes the output value, or
  // mutates only the output value within the output domain built from the
  // current input values.
  void Mutate(corpus_type& val, absl::BitGenRef prng, bool only_shrink) {
    // There is no way to tell whether the current output corpus value is
    // consistent with a new output domain generated by mutated inputs, so
    // mutating the inputs forces re-initialization of the output domain. This
    // means that, when shrinking, we cannot mutate the inputs, as
    // re-initializing would lose the "still crashing" output value.
    constexpr double kMutateInputsProbability = 0.1;
    // `only_shrink` is checked first so that `prng` is not consumed when
    // shrinking.
    bool mutate_inputs =
        !only_shrink && absl::Bernoulli(prng, kMutateInputsProbability);
    if (mutate_inputs) {
      ApplyIndex<kNumInputValues>([&](auto... I) {
        // The first field of `val` is the output corpus value, so skip it.
        (std::get<I>(input_domains_)
             .Mutate(std::get<I + 1>(val), prng, only_shrink),
         ...);
      });
      std::get<0>(val) = GetOutputDomain(val).Init(prng);
      return;
    }
    // For simplicity, we create a new output domain each call to `Mutate`. This
    // means that stateful domains don't work, but this is currently a matter of
    // convenience, not correctness. For example, `Filter` won't automatically
    // find when something is too restrictive.
    // TODO(b/246423623): Support stateful domains.
    GetOutputDomain(val).Mutate(std::get<0>(val), prng, only_shrink);
  }

  // Returns the user value of the output corpus value, as interpreted by the
  // output domain rebuilt from the input values in `v`.
  value_type GetValue(const corpus_type& v) const {
    return GetOutputDomain(v).GetValue(std::get<0>(v));
  }

  // Always returns `std::nullopt`.
  std::optional<corpus_type> FromValue(const value_type&) const {
    // We cannot infer the input corpus from the output value, or even determine
    // from which output domain the output value came.
    return std::nullopt;
  }

  // Returns a `FlatMappedPrinter` over copies of the flat mapper and the input
  // domains.
  auto GetPrinter() const {
    return FlatMappedPrinter<FlatMapper, InputDomain...>{flat_mapper_,
                                                         input_domains_};
  }

  // Parses the input values from `obj` (skipping its first sub-object, which
  // holds the output value), validates them, rebuilds the output domain, and
  // parses the output value from the first sub-object. Returns `std::nullopt`
  // on any failure. Invalid input values are also reported on stderr.
  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
    auto input_corpus = ParseWithDomainTuple(input_domains_, obj, /*skip=*/1);
    if (!input_corpus.has_value()) {
      return std::nullopt;
    }
    absl::Status input_values_validity = ValidateInputValues(*input_corpus);
    if (!input_values_validity.ok()) {
      // Terminate the line so later stderr output does not run into it.
      absl::FPrintF(GetStderr(), "[!] %s\n", input_values_validity.message());
      return std::nullopt;
    }
    auto output_domain = GetOutputDomain(*input_corpus);
    // We know obj.Subs()[0] exists because ParseWithDomainTuple succeeded.
    auto output_corpus = output_domain.ParseCorpus((*obj.Subs())[0]);
    if (!output_corpus.has_value()) {
      return std::nullopt;
    }
    // Both parsed pieces are locals that are not used past this point, so move
    // them into the result instead of copying.
    return std::tuple_cat(std::make_tuple(*std::move(output_corpus)),
                          *std::move(input_corpus));
  }

  // Serializes `v` element-wise. Element 0 uses the output domain rebuilt from
  // the input values in `v`. The remaining elements use the matching input
  // domains.
  IRObject SerializeCorpus(const corpus_type& v) const {
    auto domain =
        std::tuple_cat(std::make_tuple(GetOutputDomain(v)), input_domains_);
    return SerializeWithDomainTuple(domain, v);
  }

  // Validates the input values first and returns the first error found. Only
  // if they are all valid is the output domain rebuilt, and the output value
  // is then validated against it.
  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
    // Check input values first.
    absl::Status input_values_validity = ValidateInputValues(corpus_value);
    if (!input_values_validity.ok()) return input_values_validity;
    // Check the output value.
    return GetOutputDomain(corpus_value)
        .ValidateCorpusValue(std::get<0>(corpus_value));
  }

 private:
  // Returns the output domain for a `tuple` with or without the output value
  // as the leading element, and with the input values as the last
  // `kNumInputValues` elements. The flat mapper is called with the user value
  // of each input corpus value.
  template <typename Tuple>
  FlatMapOutputDomain<FlatMapper, InputDomain...> GetOutputDomain(
      const Tuple& tuple) const {
    static_assert(is_tuple_v<Tuple> &&
                  std::tuple_size_v<Tuple> >= kNumInputValues);
    // Index of the first input value: 0 when `tuple` holds only input values,
    // 1 when it also carries the leading output value.
    static constexpr size_t kOffset =
        std::tuple_size_v<Tuple> - kNumInputValues;
    return ApplyIndex<kNumInputValues>([&](auto... I) {
      // The first field of `tuple` may be the output corpus value, so skip it.
      return flat_mapper_(std::get<I>(input_domains_)
                              .GetValue(std::get<kOffset + I>(tuple))...);
    });
  }

  // Validates the input values for a `tuple` with or without the output value
  // as the leading element, and with the input values as the last
  // `kNumInputValues` elements. Stops at the first invalid input value and
  // returns its status, prefixed with a FlatMap-specific message.
  template <typename Tuple>
  absl::Status ValidateInputValues(const Tuple& tuple) const {
    static_assert(is_tuple_v<Tuple> &&
                  std::tuple_size_v<Tuple> >= kNumInputValues);
    static constexpr size_t kOffset =
        std::tuple_size_v<Tuple> - kNumInputValues;
    return ApplyIndex<kNumInputValues>([&](auto... I) {
      absl::Status input_values_validity = absl::OkStatus();
      (
          [&] {
            // Skip the remaining inputs once an error has been recorded.
            if (!input_values_validity.ok()) return;
            const absl::Status s =
                std::get<I>(input_domains_)
                    .ValidateCorpusValue(std::get<kOffset + I>(tuple));
            input_values_validity =
                Prefix(s, "Invalid value for FlatMap()-ed domain");
          }(),
          ...);
      return input_values_validity;
    });
  }

  static constexpr size_t kNumInputValues = sizeof...(InputDomain);
  FlatMapper flat_mapper_;
  std::tuple<InputDomain...> input_domains_;
};

}  // namespace fuzztest::internal

#endif  // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_
